{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04191.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04191.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bieroHZU6DMw",
        "colab_type": "text"
      },
      "source": [
        "## 5.4 池化层"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8SxlYyeP6bnL",
        "colab_type": "text"
      },
      "source": [
        "### 5.4.1 二维最大池化层和平均池化层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8maVP_gZ6kad",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch \n",
        "from torch import nn \n",
        "\n",
        "\n",
        "def pool2d(X, pool_size, mode='max'):\n",
        "    X = X.float()\n",
        "    p_h, p_w = pool_size \n",
        "    Y = torch.zeros(X.shape[0] - p_h + 1, X.shape[1] - p_w + 1)\n",
        "    for i in range(Y.shape[0]):\n",
        "        for j in range(Y.shape[1]):\n",
        "            if mode == 'max':\n",
        "                Y[i, j] = X[i: i + p_h, j: j + p_w].max()\n",
        "            elif mode == 'avg':\n",
        "                Y[i, j] = X[i: i + p_h, j: j + p_w].mean()\n",
        "\n",
        "    return Y"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x-O0cHXN7dxs",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "995a40d6-dabe-42d0-ea99-8d6f4f5b51f7"
      },
      "source": [
        "X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])\n",
        "pool2d(X, (2, 2))"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[4., 5.],\n",
              "        [7., 8.]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9UD0scjR7n3d",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "3581e5b0-93b5-4728-d3c2-6051c8f10374"
      },
      "source": [
        "pool2d(X, (2, 2), 'avg')"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[2., 3.],\n",
              "        [5., 6.]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KoPYTAbC7ura",
        "colab_type": "text"
      },
      "source": [
        "### 5.4.2 填充和步幅"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bv4m6JUC8grJ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "65508eed-c08d-434c-d14c-953cc8ad6c88"
      },
      "source": [
        "X = torch.arange(16, dtype=torch.float).view((1, 1, 4, 4))\n",
        "X"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[[ 0.,  1.,  2.,  3.],\n",
              "          [ 4.,  5.,  6.,  7.],\n",
              "          [ 8.,  9., 10., 11.],\n",
              "          [12., 13., 14., 15.]]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vLV3fD7A8qFh",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b6876d85-99d4-4014-b2cb-04d34bfedd01"
      },
      "source": [
        "pool2d = nn.MaxPool2d(3)\n",
        "pool2d(X)"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[[10.]]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JD9-gm2v8-xX",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "0fc1e178-aee1-4024-e1ec-947a2009f95b"
      },
      "source": [
        "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n",
        "pool2d(X)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[[ 5.,  7.],\n",
              "          [13., 15.]]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lOsFxJb79Hgx",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 67
        },
        "outputId": "91102f5e-2122-48d5-f626-a3cc49651b74"
      },
      "source": [
        "pool2d = nn.MaxPool2d((2, 4), padding=(1, 2), stride=(2, 3))\n",
        "pool2d(X)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[[ 1.,  3.],\n",
              "          [ 9., 11.],\n",
              "          [13., 15.]]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eLHlnsPN9R0A",
        "colab_type": "text"
      },
      "source": [
        "### 5.4.3 多通道"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iMgv80_89VIv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 168
        },
        "outputId": "f50b5f4a-1c5f-441b-e9ac-976481f52c1d"
      },
      "source": [
        "X = torch.cat((X, X + 1), dim=1)\n",
        "X "
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[[ 0.,  1.,  2.,  3.],\n",
              "          [ 4.,  5.,  6.,  7.],\n",
              "          [ 8.,  9., 10., 11.],\n",
              "          [12., 13., 14., 15.]],\n",
              "\n",
              "         [[ 1.,  2.,  3.,  4.],\n",
              "          [ 5.,  6.,  7.,  8.],\n",
              "          [ 9., 10., 11., 12.],\n",
              "          [13., 14., 15., 16.]]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5U8yg0Ya9cPW",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "cecdd25a-fe38-473b-baba-f119a9de58ef"
      },
      "source": [
        "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n",
        "pool2d(X)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[[ 5.,  7.],\n",
              "          [13., 15.]],\n",
              "\n",
              "         [[ 6.,  8.],\n",
              "          [14., 16.]]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    }
  ]
}